{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# This is a tutorial on using this library\n",
    "# first off we need a text_encoder so we would know our vocab_size (and later on use it to encode sentences)\n",
    "from data.vocab import SentencePieceTextEncoder  # you could also import OpenAITextEncoder\n",
    "\n",
    "sentence_piece_encoder = SentencePieceTextEncoder(text_corpus_address='openai/model/params_shapes.json',\n",
    "                                                  model_name='tutorial', vocab_size=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# now we need a sequence encoder\n",
    "from transformer.model import create_transformer\n",
    "\n",
    "sequence_encoder_config = {\n",
    "    'embedding_dim': 6,\n",
    "    'vocab_size': sentence_piece_encoder.vocab_size,\n",
    "    'max_len': 8,\n",
    "    'trainable_pos_embedding': False,\n",
    "    'num_heads': 2,\n",
    "    'num_layers': 3,\n",
    "    'd_hid': 12,\n",
    "    'use_attn_mask': True\n",
    "}\n",
    "sequence_encoder = create_transformer(**sequence_encoder_config)\n",
    "import keras\n",
    "\n",
    "assert type(sequence_encoder) == keras.Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# now look at the inputs:\n",
    "print(sequence_encoder.inputs)  # tokens, segment_ids, pos_ids, attn_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# tokens is a batch_size * seq_len tensor containing token_ids\n",
    "# segment_ids is a batch_size * seq_len tensor containing segment_ids (as in segment_{a, b} of BERT)\n",
    "# pos_ids is a batch_size * seq_len tensor containing position ids (0..max_len)(you will see how can easily generate it)\n",
    "# attn_mask is a batch_size * 1 * max_len * max_len tensor and can encode padding and causality constraints (ignore it for now)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# for outputs we have:\n",
    "print(sequence_encoder.outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 'a long name' is a batch_size * max_len * embedding_dim tensor which is our encoded sequence (here with a transformer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# now it's time to train it both on pre-training tasks and fine-tuning tasks\n",
    "# first we need to define our tasks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from data.dataset import TaskMetadata, TaskWeightScheduler\n",
    "\n",
    "tasks = [TaskMetadata('lm', is_token_level=True,\n",
    "                      num_classes=sentence_piece_encoder.vocab_size + sentence_piece_encoder.SPECIAL_COUNT,\n",
    "                      dropout=0,\n",
    "                      weight_scheduler=TaskWeightScheduler(active_in_pretrain=True, active_in_finetune=False,\n",
    "                                                           pretrain_value=1.0))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# well let's pause and see what this task is, 'lm' is the name of the task\n",
    "# and 'lm' is also a special task, because it uses a tied decoder (if you don't know what it means, ignore it)\n",
    "# then num_classes is set to vocab+special_count which is actually incorrect (we are never going to predict mask, pad, )\n",
    "# but it's here for the tied decoder to work; dropout is for the decoder of this task\n",
    "# and finally a weight_scheduler, in this example we are only training on 'lm' task during the pretraing but not after\n",
    "# now let's add a more complex task, a sentence level one with a complex weight_scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class ComplexTaskWeightScheduler(TaskWeightScheduler):  # note: this is an example, it is not a clean code\n",
    "    def __init__(self, number_of_pretrain_steps, number_of_finetune_steps):\n",
    "        super().__init__(active_in_pretrain=True, active_in_finetune=True)\n",
    "        self.number_of_pretrain_steps = number_of_pretrain_steps\n",
    "        self.number_of_finetune_steps = number_of_finetune_steps\n",
    "\n",
    "    def get(self, is_pretrain: bool, step: int) -> float:\n",
    "        return step / (self.number_of_pretrain_steps if is_pretrain else self.number_of_finetune_steps)\n",
    "\n",
    "\n",
    "number_of_pretrain_steps = 100\n",
    "number_of_finetune_steps = 100\n",
    "# in this task we are going to count the number of tokens in a sentence and predict if it's odd or not\n",
    "tasks.append(TaskMetadata('odd', is_token_level=False, num_classes=2, dropout=0.3,\n",
    "                          weight_scheduler=ComplexTaskWeightScheduler(number_of_pretrain_steps,\n",
    "                                                                      number_of_finetune_steps)))\n",
    "\n",
    "# and let's add a unsolvable task for fun\n",
    "tasks.append(TaskMetadata('lm_random', is_token_level=True,\n",
    "                          num_classes=sentence_piece_encoder.vocab_size + sentence_piece_encoder.SPECIAL_COUNT,\n",
    "                          dropout=0.3,\n",
    "                          weight_scheduler=TaskWeightScheduler(active_in_pretrain=True, active_in_finetune=True,\n",
    "                                                               pretrain_value=0.5)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# now we need a data generator, for a good reference see data.lm_dataset._get_lm_generator_single or _double\n",
    "# but for now we are going to write a simple one so you understand the Sentence class\n",
    "# again this is a simple generator just showing you the core ideas\n",
    "# so for 'lm' task we are just going to predict the token itself (identity function)\n",
    "# first we are importing things, ignore them for now, I will explain them in a bit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from data.dataset import Sentence, TokenTaskData, SentenceTaskData, TextEncoder\n",
    "from data.lm_dataset import _create_batch\n",
    "import random\n",
    "\n",
    "\n",
    "def tutorial_batch_generator(vocab_size: int, max_len: int, batch_size: int, steps: int):\n",
    "    def sentence_generator():\n",
    "        for _ in range(steps):\n",
    "            # for each sentence we are going to generate up to max_len tokens\n",
    "            seq_len = random.randint(1, max_len - 1)\n",
    "            # and this is their ids (in reality we have to use our TextEncoder instance here)\n",
    "            tokens = [random.randrange(vocab_size) for _ in range(seq_len)]\n",
    "            # we manually set the last token to EOS (which we will see how it's calculated)\n",
    "            tokens[-1] = eos_id\n",
    "            yield Sentence(\n",
    "                tokens=tokens,\n",
    "                padding_mask=[True] * seq_len,  # it means that non of the original tokens are padding\n",
    "                segments=[0] * seq_len,  # for this simple example we are going to use segment_a(0) for all of them\n",
    "                token_classification={  # we put labels here (for token level tasks)\n",
    "                    # name_of_the_task: TokenTaskData(target(aka label), label_mask)\n",
    "                    # there might be situations that you are only interested in predictions for certain tokens,\n",
    "                    # you can use mask in those situations (see the bert paper to understand this)\n",
    "                    'lm': TokenTaskData(tokens, [True] * seq_len),\n",
    "                    # this task is unsolvable so we will see the loss not decreasing\n",
    "                    'lm_random': TokenTaskData([random.randrange(vocab_size) for i in range(seq_len)],\n",
    "                                               [True] * seq_len)\n",
    "                },\n",
    "                # similar to token_classification, it's also a dictionary of task to label\n",
    "                # SentenceTaskData contains (label, where to extract that label_from)\n",
    "                # in this case we are going to predict whether a sentence has\n",
    "                # odd number of tokens or not whenever we see eos token\n",
    "                sentence_classification={'odd': SentenceTaskData(seq_len % 2, seq_len - 1)}\n",
    "            )\n",
    "\n",
    "    # we need eos_id and it's always at this place\n",
    "    eos_id = vocab_size + TextEncoder.EOS_OFFSET\n",
    "    # likewise for pad_id\n",
    "    pad_id = vocab_size + TextEncoder.PAD_OFFSET\n",
    "    generator = sentence_generator()\n",
    "    batch = []\n",
    "    for item in generator:\n",
    "        batch.append(item)\n",
    "        if len(batch) == batch_size:\n",
    "            batch = _create_batch(batch, pad_id, max_len)  # magic to pad and batch sentences\n",
    "            # at the end it will generate a SentenceBatch which is more than just a list of Sentence\n",
    "            yield batch\n",
    "            batch = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# now we instantiate our generator\n",
    "# we are going to set steps to a large number (it doesn't matter)\n",
    "# we have to set batch_size too"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "batch_size = 5\n",
    "generator = tutorial_batch_generator(sentence_piece_encoder.vocab_size, sequence_encoder_config['max_len'],\n",
    "                                     batch_size, steps=10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# now let the fun begin :D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from transformer.train import train_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# we are going to use the same generator for both pretrain and finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "m = train_model(base_model=sequence_encoder, is_causal=False, tasks_meta_data=tasks, pretrain_generator=generator,\n",
    "                finetune_generator=generator, pretrain_epochs=100, pretrain_steps=number_of_pretrain_steps // 100,\n",
    "                finetune_epochs=100, finetune_steps=number_of_finetune_steps // 100, verbose=2)\n",
    "# now m is ready to be used!\n",
    "print(m.inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# token, segment, pos, att_mask, odd_mask (where to extract the class from)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "bs = 6\n",
    "vs = sentence_piece_encoder.vocab_size\n",
    "sl = sequence_encoder_config['max_len']\n",
    "# generate random tokens\n",
    "token = np.random.randint(0, vs, (bs, sl))\n",
    "# generate random seg_id\n",
    "segment = np.random.randint(0, 2, (bs, sl))\n",
    "# generate pos_id\n",
    "from transformer.train import generate_pos_ids\n",
    "\n",
    "pos = generate_pos_ids(bs, sl)\n",
    "# generate attn_mask\n",
    "from data.dataset import create_attention_mask\n",
    "\n",
    "# first generate a padding_mask(1 means it is not padded)\n",
    "pad_mask = np.random.randint(0, 2, (bs, sl)).astype(np.int8)\n",
    "# create the mask\n",
    "mask = create_attention_mask(pad_mask=pad_mask, is_causal=False)\n",
    "# generate target index\n",
    "target_index = np.random.randint(0, sl, (bs, 1))\n",
    "res = m.predict([token, segment, pos, mask, target_index], verbose=2)\n",
    "assert res[0].shape == (bs, sl, vs + TextEncoder.SPECIAL_COUNT)  # lm\n",
    "assert res[1].shape == (bs, 1, 2)  # odd\n",
    "assert res[2].shape == (bs, sl, vs + TextEncoder.SPECIAL_COUNT)  # random_lm"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
